{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63ff8110-1d09-4aa1-b296-0c038caac7f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the notebook's dependencies.\n",
    "# `%pip` (instead of `!pip`) installs into the environment of the running kernel,\n",
    "# so the packages are importable from the cells below.\n",
    "%pip install lightning typing_extensions \"pydantic<2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b847aa6-10dd-409f-8166-7a6c6104baff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from torch import optim, nn, utils, Tensor\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import ToTensor\n",
    "import lightning.pytorch as pl\n",
    "\n",
    "# define any number of nn.Modules (or use your current ones)\n",
    "encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))\n",
    "decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))\n",
    "\n",
    "\n",
    "# define the LightningModule\n",
    "class LitAutoEncoder(pl.LightningModule):\n",
    "    \"\"\"LightningModule that trains an encoder/decoder pair to reconstruct its input.\"\"\"\n",
    "\n",
    "    def __init__(self, encoder, decoder):\n",
    "        \"\"\"Keep references to the two sub-networks.\n",
    "\n",
    "        Args:\n",
    "            encoder: module applied to the flattened input batch.\n",
    "            decoder: module applied to the encoder's output.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.encoder, self.decoder = encoder, decoder\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the reconstruction loss for one batch.\n",
    "\n",
    "        This method is the body of the training loop; it does not go through ``forward``.\n",
    "\n",
    "        Args:\n",
    "            batch: a pair whose first item is the input tensor; the second item is ignored.\n",
    "            batch_idx: index of the batch (unused here).\n",
    "\n",
    "        Returns:\n",
    "            The MSE between the decoder output and the flattened input.\n",
    "        \"\"\"\n",
    "        inputs, _ = batch\n",
    "        flat = inputs.view(inputs.size(0), -1)  # one row per sample\n",
    "        reconstruction = self.decoder(self.encoder(flat))\n",
    "        loss = nn.functional.mse_loss(reconstruction, flat)\n",
    "        # Hand the scalar to the Trainer's logger (TensorBoard by default, if installed)\n",
    "        self.log(\"train_loss\", loss)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Return an Adam optimizer over all of the module's parameters (lr=1e-3).\"\"\"\n",
    "        return optim.Adam(self.parameters(), lr=1e-3)\n",
    "\n",
    "\n",
    "# Wire the encoder and decoder modules into the LightningModule.\n",
    "autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc49795f-6333-4bb1-ac9c-168fc511cb91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download=True lets torchvision fetch MNIST into the current working directory;\n",
    "# ToTensor is applied to each image as the transform.\n",
    "dataset = MNIST(root=os.getcwd(), transform=ToTensor(), download=True)\n",
    "# No batch_size / shuffle are passed, so DataLoader's defaults apply.\n",
    "train_loader = utils.data.DataLoader(dataset=dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f645b81-9703-4324-93ea-297041afb7d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from determined.experimental import core_v2\n",
    "from determined.lightning.experimental import DetLogger\n",
    "\n",
    "# Logger from Determined's Lightning integration, configured with the default name \"ptl-demo\".\n",
    "det_logger = DetLogger(\n",
    "    defaults=core_v2.DefaultConfig(name=\"ptl-demo\"),\n",
    ")\n",
    "# Short demo run: a single epoch, capped at 100 training batches.\n",
    "trainer = pl.Trainer(\n",
    "    logger=det_logger,\n",
    "    max_epochs=1,\n",
    "    limit_train_batches=100,\n",
    ")\n",
    "trainer.fit(autoencoder, train_dataloaders=train_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
